import os
import torch


class ModelLogger:
    """Save model checkpoints with ``torch.save`` at selected epochs.

    Checkpoints are written to ``<save_dir>/<exp_name>/epoch<N>.pth``.

    Args:
        save_dir: Root directory under which the experiment directory lives.
        exp_name: Name of the experiment sub-directory (created if missing).
        epoch_to_save: Container of epoch numbers to checkpoint. It is tested
            with ``in`` on every :meth:`log` call, so it must support repeated
            membership tests (e.g. a list, set or range -- not a generator).
    """

    def __init__(self, save_dir, exp_name, epoch_to_save):
        self.save_dir = save_dir
        self.exp_name = exp_name
        self.epoch_to_save = epoch_to_save
        # exist_ok=True replaces the former exists()-then-makedirs() sequence,
        # which raced when several processes created the same directory.
        os.makedirs(self.model_dir, 0o0777, exist_ok=True)

    @property
    def model_dir(self):
        """Directory holding this experiment's checkpoints."""
        # Recomputed on access so later edits to save_dir/exp_name still apply.
        return os.path.join(self.save_dir, self.exp_name)

    def log(self, model, epoch):
        """Save *model* if *epoch* is one of the selected epochs.

        Args:
            model: Object handed to ``torch.save`` (module, state dict, ...).
            epoch: Current epoch number; must be an ``int`` because it is
                formatted with ``{:d}``.

        Returns:
            A status message when a checkpoint was written, otherwise ``None``.
        """
        if epoch not in self.epoch_to_save:
            return None
        filename = 'epoch{:d}.pth'.format(epoch)
        torch.save(model, os.path.join(self.model_dir, filename))
        return 'ModelLogger: {} saved'.format(filename)
